{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## 第七章.卷积神经网络.\n",
    "<p>卷积神经网络分为三大层:①卷积层。②池化层。③全连接层</p>"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(9, 75)\n",
      "(90, 75)\n"
     ]
    }
   ],
   "source": [
    "import sys,os\n",
    "from common.util import im2col\n",
    "import numpy as np\n",
    "\n",
    "x1 = np.random.rand(1,3,7,7)\n",
    "col1 = im2col(x1,5,5,stride=1,pad=0)\n",
    "print(col1.shape)\n",
    "\n",
    "x2 = np.random.rand(10,3,7,7)# 10个数据\n",
    "col2 = im2col(x2,5,5,stride=1,pad=0)\n",
    "print(col2.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-04T01:58:32.327649400Z",
     "start_time": "2023-10-04T01:58:31.887341700Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 实现卷积层"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "class Convolution:\n",
    "    def __init__(self,W,b,stride=1,pad=0):\n",
    "        self.W = W\n",
    "        self.b = b\n",
    "        self.stride = stride\n",
    "        self.pad = pad\n",
    "\n",
    "    def forward(self,x):\n",
    "        FN,C,FH,FW = self.W.shape\n",
    "        N,C,H,W = x.shape\n",
    "        out_h = int(1+(H+2*self.pad-FH)/self.stride)\n",
    "        out_w = int(1+(W+2*self.pad-FW)/self.stride)\n",
    "\n",
    "        col = im2col(x,FH,FW,self.stride,self.pad)\n",
    "        col_W = self.W.reshape(FN,-1).T # 滤波器的展开\n",
    "        out = np.dot(col,col_W)+self.b\n",
    "\n",
    "        out = out.reshape(N,out_h,out_w,-1).transpose(0,3,1,2)\n",
    "        return out"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-04T02:11:23.434868500Z",
     "start_time": "2023-10-04T02:11:23.396875500Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 池化层的实现\n",
    "\n",
    "<p>池化层的实现步骤分三阶段进行</p>\n",
    "<ol>\n",
    "<li>产看输入数据</li>\n",
    "<li>求各行的最大值</li>\n",
    "<li>转换为合适的输出大小</li>\n",
    "</ol>"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "class Pooling:\n",
    "    def __init__(self,pool_h,pool_w,stride=1,pad=6):\n",
    "        self.pool_h = pool_h\n",
    "        self.pool_w = pool_w\n",
    "        self.stride = stride\n",
    "        self.pad = pad\n",
    "\n",
    "    def forward(self,x):\n",
    "        N,C,H,W = x.shape\n",
    "        out_h = int(1+(H-self.pool_h)/self.stride)\n",
    "        out_w = int(1+(W-self.pool_w)/self.stride)\n",
    "        # 展开1\n",
    "        col = im2col(x,self.pool_h,self.pool_w,self.stride,self.pad)\n",
    "        col = col.reshape(-1,self.pool_h*self.pool_w)\n",
    "        # 最大值2\n",
    "        out = np.max(col,axis=1)\n",
    "        # 转换\n",
    "        out = out.reshape(N,out_h,out_w,C).transpose(0,3,1,2)\n",
    "        return out\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-04T02:35:18.866401100Z",
     "start_time": "2023-10-04T02:35:18.819378900Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### CNN的实现"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [],
   "source": [
    "import pickle\n",
    "from collections import OrderedDict\n",
    "from common.layers import Relu,Affine,SoftmaxWithLoss,Pooling,Convolution\n",
    "\n",
    "class SimpleConvNet:\n",
    "    def __init__(self,input_dim=(1,28,28),\n",
    "                 conv_param={'filter_num':30,'filter_size':5,'pad':0,'stride':1},\n",
    "                 hidden_size = 100,output_size=10,weight_init_std=0.01\n",
    "                 ):\n",
    "        '''\n",
    "        :param input_dim:输入数据的维度:(通道,高,长)\n",
    "        :param conv_param: 卷积层的超参数(字典)\n",
    "        :param hidden_size: 隐藏层(全连接的数量)\n",
    "        :param output_size: 输出层(全连接)的神经元数量\n",
    "        :param weight_init_std:初始化时权重的标准差\n",
    "        '''\n",
    "        filter_num = conv_param['filter_num']\n",
    "        filter_size = conv_param['filter_size']\n",
    "        filter_pad = conv_param['pad']\n",
    "        filter_stride = conv_param['stride']\n",
    "        input_size = input_dim[1]\n",
    "        conv_output_size = (input_size-filter_size+2*filter_pad)/filter_stride +1\n",
    "        pool_output_size = int(filter_num*(conv_output_size/2)*(conv_output_size/2))\n",
    "\n",
    "        # 权重参数的初始化\n",
    "        self.params = {}\n",
    "        self.params['W1'] = weight_init_std*np.random.randn(filter_num,input_dim[0],filter_size,filter_size)\n",
    "        self.params['b1'] = np.zeros(filter_num)\n",
    "        self.params['W2'] = weight_init_std*np.random.randn(pool_output_size,hidden_size)\n",
    "        self.params['b2'] = np.zeros(hidden_size)\n",
    "        self.params['W3'] = weight_init_std* np.random.randn(hidden_size,output_size)\n",
    "        self.params['b3'] = np.zeros(output_size)\n",
    "\n",
    "        self.layers = OrderedDict()\n",
    "        self.layers['Conv1'] =Convolution(self.params['W1'],\n",
    "                                          self.params['b1'],\n",
    "                                          conv_param['stride'],\n",
    "                                          conv_param['pad']\n",
    "                                          )\n",
    "        self.layers['Relu1'] = Relu()\n",
    "        self.layers['Pool1'] = Pooling(pool_h=2,pool_w=2,stride=2)\n",
    "        self.layers['Affine1'] = Affine(self.params['W2'],self.params['b2'])\n",
    "        self.layers['Relu2'] = Relu()\n",
    "        self.layers['Affine2'] = Affine(self.params['W3'],self.params['b3'])\n",
    "        self.lastLayer = SoftmaxWithLoss()\n",
    "\n",
    "    def predict(self,x):\n",
    "        for layer in self.layers.values():\n",
    "            x = layer.forward(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "    def accuracy(self, x, t, batch_size=100):\n",
    "        if t.ndim != 1 : t = np.argmax(t, axis=1)\n",
    "\n",
    "        acc = 0.0\n",
    "\n",
    "        for i in range(int(x.shape[0] / batch_size)):\n",
    "            tx = x[i*batch_size:(i+1)*batch_size]\n",
    "            tt = t[i*batch_size:(i+1)*batch_size]\n",
    "            y = self.predict(tx)\n",
    "            y = np.argmax(y, axis=1)\n",
    "            acc += np.sum(y == tt)\n",
    "\n",
    "        return acc / x.shape[0]\n",
    "\n",
    "    def save_params(self, file_name=\"params.pkl\"):\n",
    "        params = {}\n",
    "        for key, val in self.params.items():\n",
    "            params[key] = val\n",
    "        with open(file_name, 'wb') as f:\n",
    "            pickle.dump(params, f)\n",
    "\n",
    "    def load_params(self, file_name=\"params.pkl\"):\n",
    "        with open(file_name, 'rb') as f:\n",
    "            params = pickle.load(f)\n",
    "        for key, val in params.items():\n",
    "            self.params[key] = val\n",
    "\n",
    "        for i, key in enumerate(['Conv1', 'Affine1', 'Affine2']):\n",
    "            self.layers[key].W = self.params['W' + str(i+1)]\n",
    "            self.layers[key].b = self.params['b' + str(i+1)]\n",
    "\n",
    "    def loss(self,x,t):\n",
    "        y = self.predict(x)\n",
    "        return self.lastLayer.forward(y,t)\n",
    "\n",
    "    def gradient(self,x,t):\n",
    "        # forward\n",
    "        self.loss(x,t)\n",
    "        # backward\n",
    "        dout = 1\n",
    "        dout = self.lastLayer.backward(dout)\n",
    "\n",
    "        layers = list(self.layers.values())\n",
    "        layers.reverse()\n",
    "        for layer in layers:\n",
    "            dout = layer.backward(dout)\n",
    "\n",
    "        # 设定\n",
    "        grads = {}\n",
    "        grads['W1'] = self.layers['Conv1'].dW\n",
    "        grads['b1'] = self.layers['Conv1'].db\n",
    "        grads['W2'] = self.layers['Affine1'].dW\n",
    "        grads['b2'] = self.layers['Affine1'].db\n",
    "        grads['W3'] = self.layers['Affine2'].dW\n",
    "        grads['b3'] = self.layers['Affine2'].db\n",
    "\n",
    "        return grads"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-04T04:30:24.961546Z",
     "start_time": "2023-10-04T04:30:24.936546500Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:2.299143113831111\n",
      "=== epoch:1, train acc:0.112, test acc:0.128 ===\n",
      "train loss:2.2952922368260205\n",
      "train loss:2.2907960386721418\n",
      "train loss:2.2817013644665995\n",
      "train loss:2.2688686891402035\n",
      "train loss:2.2549414982756573\n",
      "train loss:2.2429060097481437\n",
      "train loss:2.2174980785785308\n",
      "train loss:2.213592791509185\n",
      "train loss:2.180652327971934\n",
      "train loss:2.1232016512613177\n",
      "train loss:2.1020957496181176\n",
      "train loss:2.076849736648811\n",
      "train loss:2.0079407323307716\n",
      "train loss:1.943382866260245\n",
      "train loss:1.8865549062798872\n",
      "train loss:1.8565880627189144\n",
      "train loss:1.746621452048782\n",
      "train loss:1.621099721417907\n",
      "train loss:1.5407261142901054\n",
      "train loss:1.553876747356715\n",
      "train loss:1.3469784478402236\n",
      "train loss:1.3162512875497066\n",
      "train loss:1.2393306486779736\n",
      "train loss:1.2136225532934781\n",
      "train loss:1.0349522560060307\n",
      "train loss:1.0700380970150838\n",
      "train loss:1.1152833770995072\n",
      "train loss:0.9814435554689901\n",
      "train loss:1.047367881180291\n",
      "train loss:0.7684692070387146\n",
      "train loss:0.7396630582034686\n",
      "train loss:0.6965426682262248\n",
      "train loss:0.8037335184793623\n",
      "train loss:0.7644508046917128\n",
      "train loss:0.9354501210359558\n",
      "train loss:0.6123663785076304\n",
      "train loss:0.6442141452122857\n",
      "train loss:0.6845878322733541\n",
      "train loss:0.7924188255231306\n",
      "train loss:0.6416068453052705\n",
      "train loss:0.8678234302359704\n",
      "train loss:0.6356275508423976\n",
      "train loss:0.47063120415704424\n",
      "train loss:0.7129202447040087\n",
      "train loss:0.3899344104530196\n",
      "train loss:0.49432248370840237\n",
      "train loss:0.5093966160278022\n",
      "train loss:0.5302437536578929\n",
      "train loss:0.4682954725980701\n",
      "train loss:0.4865224144887111\n",
      "=== epoch:2, train acc:0.828, test acc:0.818 ===\n",
      "train loss:0.5024037114454533\n",
      "train loss:0.5407844835302449\n",
      "train loss:0.4490949302221485\n",
      "train loss:0.40007544573076176\n",
      "train loss:0.4470633667962828\n",
      "train loss:0.42328310732910474\n",
      "train loss:0.57218860712643\n",
      "train loss:0.46494637083293017\n",
      "train loss:0.6383860808179223\n",
      "train loss:0.42205863585927267\n",
      "train loss:0.38983315809995256\n",
      "train loss:0.5736867449150578\n",
      "train loss:0.35497724836943917\n",
      "train loss:0.3479518309290959\n",
      "train loss:0.6414539618041789\n",
      "train loss:0.41729292662188966\n",
      "train loss:0.5215375996785786\n",
      "train loss:0.4941483741604764\n",
      "train loss:0.4332736455135162\n",
      "train loss:0.3240180972078002\n",
      "train loss:0.2913033827328298\n",
      "train loss:0.4063600998086413\n",
      "train loss:0.40909858798026827\n",
      "train loss:0.37955922710849377\n",
      "train loss:0.5099509494773201\n",
      "train loss:0.3812616269741798\n",
      "train loss:0.42577044196022656\n",
      "train loss:0.4085145828393927\n",
      "train loss:0.4218602845327475\n",
      "train loss:0.21411546485519412\n",
      "train loss:0.4152164425978252\n",
      "train loss:0.2798533541535117\n",
      "train loss:0.2786162583879571\n",
      "train loss:0.3074597765349542\n",
      "train loss:0.37805031343035145\n",
      "train loss:0.3901858800109562\n",
      "train loss:0.2840447566483156\n",
      "train loss:0.5171617469492802\n",
      "train loss:0.24574821257543744\n",
      "train loss:0.22193456611722687\n",
      "train loss:0.5638107881528037\n",
      "train loss:0.3593984586779396\n",
      "train loss:0.495163268882989\n",
      "train loss:0.2664204072994977\n",
      "train loss:0.3460786474337371\n",
      "train loss:0.2810947183922839\n",
      "train loss:0.2759896993358664\n",
      "train loss:0.3400471099345024\n",
      "train loss:0.38543281533178936\n",
      "train loss:0.2879163013086329\n",
      "=== epoch:3, train acc:0.875, test acc:0.854 ===\n",
      "train loss:0.2917582595449524\n",
      "train loss:0.4750508474687009\n",
      "train loss:0.3654510298606696\n",
      "train loss:0.23762484195125033\n",
      "train loss:0.30640290047491914\n",
      "train loss:0.3495378854747803\n",
      "train loss:0.3954279854262577\n",
      "train loss:0.29061274988997804\n",
      "train loss:0.2509731492464822\n",
      "train loss:0.24198838347780885\n",
      "train loss:0.3940350426854653\n",
      "train loss:0.2879139862766537\n",
      "train loss:0.2572520743201624\n",
      "train loss:0.39684559359604427\n",
      "train loss:0.1710357170400979\n",
      "train loss:0.39082618352641374\n",
      "train loss:0.2834468765672944\n",
      "train loss:0.26241549277540765\n",
      "train loss:0.23644112345824211\n",
      "train loss:0.32050387743539765\n",
      "train loss:0.2504830930792805\n",
      "train loss:0.34416316274044684\n",
      "train loss:0.3691156057162112\n",
      "train loss:0.2437548830448493\n",
      "train loss:0.24144628363447104\n",
      "train loss:0.25592822034968976\n",
      "train loss:0.2814794459603268\n",
      "train loss:0.24336489162439243\n",
      "train loss:0.4002820015534181\n",
      "train loss:0.3562117067718819\n",
      "train loss:0.24048464623737428\n",
      "train loss:0.21946149732121933\n",
      "train loss:0.2873850544636349\n",
      "train loss:0.31152232289810006\n",
      "train loss:0.2983396692622337\n",
      "train loss:0.20942308184645167\n",
      "train loss:0.25618991598021484\n",
      "train loss:0.21243328871800163\n",
      "train loss:0.3144059875006741\n",
      "train loss:0.17864494284981647\n",
      "train loss:0.17981666672882657\n",
      "train loss:0.34592135647331324\n",
      "train loss:0.21283467118002464\n",
      "train loss:0.29911773329693453\n",
      "train loss:0.3133698618828004\n",
      "train loss:0.26691149700753874\n",
      "train loss:0.15667362711153068\n",
      "train loss:0.44504198229069675\n",
      "train loss:0.3319721511698168\n",
      "train loss:0.30214312434589846\n",
      "=== epoch:4, train acc:0.9, test acc:0.882 ===\n",
      "train loss:0.257457705357776\n",
      "train loss:0.17596949721748092\n",
      "train loss:0.40874335206258217\n",
      "train loss:0.22790793745307872\n",
      "train loss:0.2824970595672232\n",
      "train loss:0.3523602804146046\n",
      "train loss:0.24058234389698274\n",
      "train loss:0.2723343646248853\n",
      "train loss:0.1820832171390182\n",
      "train loss:0.13682428504202354\n",
      "train loss:0.31713832987757895\n",
      "train loss:0.26913899320017287\n",
      "train loss:0.19206344448635895\n",
      "train loss:0.3872486531121456\n",
      "train loss:0.20109994943808282\n",
      "train loss:0.1387228642814377\n",
      "train loss:0.18933744397948954\n",
      "train loss:0.2603495461376498\n",
      "train loss:0.227960884596621\n",
      "train loss:0.15276861980292483\n",
      "train loss:0.2289911405724956\n",
      "train loss:0.383335522031549\n",
      "train loss:0.21596317899640705\n",
      "train loss:0.2866754665205715\n",
      "train loss:0.16698808365530562\n",
      "train loss:0.18282213602622008\n",
      "train loss:0.13208337213415355\n",
      "train loss:0.23497902554709466\n",
      "train loss:0.2214158856438391\n",
      "train loss:0.17453241234567013\n",
      "train loss:0.26311926967419635\n",
      "train loss:0.2999576830631843\n",
      "train loss:0.2352194383885969\n",
      "train loss:0.1349790198236749\n",
      "train loss:0.23946162231162918\n",
      "train loss:0.19134027556126937\n",
      "train loss:0.2993860126203588\n",
      "train loss:0.28791215024918676\n",
      "train loss:0.47534008692272034\n",
      "train loss:0.3188825219066938\n",
      "train loss:0.19562225488168428\n",
      "train loss:0.18902180034600957\n",
      "train loss:0.2534217936942096\n",
      "train loss:0.16723578151488097\n",
      "train loss:0.15816423145606623\n",
      "train loss:0.2941879934778783\n",
      "train loss:0.22994942214180522\n",
      "train loss:0.4604143461856755\n",
      "train loss:0.33074074179843654\n",
      "train loss:0.13867695766775776\n",
      "=== epoch:5, train acc:0.914, test acc:0.908 ===\n",
      "train loss:0.17335716759797182\n",
      "train loss:0.16661988621329077\n",
      "train loss:0.16967747408050055\n",
      "train loss:0.3293733062901353\n",
      "train loss:0.23154912913669734\n",
      "train loss:0.2347124152140337\n",
      "train loss:0.2627915740672021\n",
      "train loss:0.22676807798730483\n",
      "train loss:0.1419184622421613\n",
      "train loss:0.249747427991637\n",
      "train loss:0.35148259523090253\n",
      "train loss:0.235486109469636\n",
      "train loss:0.18439394273767995\n",
      "train loss:0.24581787107983724\n",
      "train loss:0.27764039067100665\n",
      "train loss:0.4322032331902267\n",
      "train loss:0.16690305573647318\n",
      "train loss:0.3755613114405206\n",
      "train loss:0.07318332739498118\n",
      "train loss:0.2814184219631675\n",
      "train loss:0.3115882058507778\n",
      "train loss:0.25384185973446693\n",
      "train loss:0.2368030880440445\n",
      "train loss:0.22816619376967207\n",
      "train loss:0.20193742595079048\n",
      "train loss:0.25940169819640646\n",
      "train loss:0.18116405769789687\n",
      "train loss:0.07960923246074193\n",
      "train loss:0.1645249608366051\n",
      "train loss:0.1513953745859174\n",
      "train loss:0.35051388673135486\n",
      "train loss:0.1522448317322129\n",
      "train loss:0.23298882416401615\n",
      "train loss:0.11357407457697745\n",
      "train loss:0.22483142962106092\n",
      "train loss:0.16157304709536924\n",
      "train loss:0.379376002759639\n",
      "train loss:0.24109861860285786\n",
      "train loss:0.111727203528494\n",
      "train loss:0.11288916971565502\n",
      "train loss:0.283980872497846\n",
      "train loss:0.13469380067869535\n",
      "train loss:0.2099551806604141\n",
      "train loss:0.16888865456677318\n",
      "train loss:0.2350779702698637\n",
      "train loss:0.2442709245847583\n",
      "train loss:0.2890487506765111\n",
      "train loss:0.2371979267837691\n",
      "train loss:0.15388616648809114\n",
      "train loss:0.20015839262263\n",
      "=== epoch:6, train acc:0.921, test acc:0.907 ===\n",
      "train loss:0.18514648120383903\n",
      "train loss:0.2156918315754456\n",
      "train loss:0.17785161970716992\n",
      "train loss:0.23239541810960174\n",
      "train loss:0.16404057838095945\n",
      "train loss:0.24956806899678557\n",
      "train loss:0.34654363528591875\n",
      "train loss:0.11289966710852173\n",
      "train loss:0.17055816879917118\n",
      "train loss:0.20359867635772733\n",
      "train loss:0.1523443861592739\n",
      "train loss:0.2721291980937557\n",
      "train loss:0.2878141156253983\n",
      "train loss:0.2582970388530781\n",
      "train loss:0.22957718895208315\n",
      "train loss:0.1984699901604556\n",
      "train loss:0.11524182444786898\n",
      "train loss:0.2511485452904381\n",
      "train loss:0.1949519897668639\n",
      "train loss:0.22436168614447474\n",
      "train loss:0.2952718714778646\n",
      "train loss:0.20130178277954658\n",
      "train loss:0.16829290193220964\n",
      "train loss:0.12516021149844966\n",
      "train loss:0.17362679122362643\n",
      "train loss:0.18636600836258876\n",
      "train loss:0.22588288348037874\n",
      "train loss:0.22300643270590295\n",
      "train loss:0.2059810264598623\n",
      "train loss:0.26825817141840813\n",
      "train loss:0.2512110251737498\n",
      "train loss:0.11476383243363744\n",
      "train loss:0.1757162737218959\n",
      "train loss:0.1479999476705242\n",
      "train loss:0.20534742690607996\n",
      "train loss:0.1726894478265429\n",
      "train loss:0.1728240990010664\n",
      "train loss:0.10656667533171085\n",
      "train loss:0.12874681277140776\n",
      "train loss:0.12332216700496276\n",
      "train loss:0.1105633858376215\n",
      "train loss:0.08984051417713122\n",
      "train loss:0.14345812874570144\n",
      "train loss:0.1839638953602397\n",
      "train loss:0.20090450080319908\n",
      "train loss:0.202618428551881\n",
      "train loss:0.10496053053629184\n",
      "train loss:0.14839720525253203\n",
      "train loss:0.16972441815947825\n",
      "train loss:0.20826822532059164\n",
      "=== epoch:7, train acc:0.941, test acc:0.918 ===\n",
      "train loss:0.1609719198629407\n",
      "train loss:0.18792325475868804\n",
      "train loss:0.24293943386896177\n",
      "train loss:0.13841614647922468\n",
      "train loss:0.22816970344692783\n",
      "train loss:0.13175546549241932\n",
      "train loss:0.13078420611588815\n",
      "train loss:0.19664087933283267\n",
      "train loss:0.14261061850627021\n",
      "train loss:0.15137949963939645\n",
      "train loss:0.20464247728009555\n",
      "train loss:0.14859435175321276\n",
      "train loss:0.24205734332325046\n",
      "train loss:0.1805261490319086\n",
      "train loss:0.19893048223773074\n",
      "train loss:0.26016045492269246\n",
      "train loss:0.17541428064788295\n",
      "train loss:0.18123275907950098\n",
      "train loss:0.14959172904893214\n",
      "train loss:0.1244737907050365\n",
      "train loss:0.10773257929163853\n",
      "train loss:0.16255020467534073\n",
      "train loss:0.1301827377027935\n",
      "train loss:0.13549328516955594\n",
      "train loss:0.06580241188633501\n",
      "train loss:0.22356676492220207\n",
      "train loss:0.19189540048111461\n",
      "train loss:0.09608842794504648\n",
      "train loss:0.1316206068657877\n",
      "train loss:0.1383512115015249\n",
      "train loss:0.1533049851617367\n",
      "train loss:0.21572447666236605\n",
      "train loss:0.08660677030681853\n",
      "train loss:0.1385304967643453\n",
      "train loss:0.13971947233526347\n",
      "train loss:0.21368869578404998\n",
      "train loss:0.08927205608541725\n",
      "train loss:0.18688265068009471\n",
      "train loss:0.30923592286380847\n",
      "train loss:0.14435770380275684\n",
      "train loss:0.17283869733950896\n",
      "train loss:0.073200828713002\n",
      "train loss:0.1020037803178797\n",
      "train loss:0.13492317278245256\n",
      "train loss:0.17402371251015275\n",
      "train loss:0.23142421896543006\n",
      "train loss:0.18129684848200117\n",
      "train loss:0.08044460195721313\n",
      "train loss:0.10826866638344171\n",
      "train loss:0.18100253519229306\n",
      "=== epoch:8, train acc:0.932, test acc:0.916 ===\n",
      "train loss:0.20060907698283273\n",
      "train loss:0.133520032044572\n",
      "train loss:0.17739064392271975\n",
      "train loss:0.14757659746595844\n",
      "train loss:0.16986846955024898\n",
      "train loss:0.13660175637029723\n",
      "train loss:0.16942509309930948\n",
      "train loss:0.17521058335930723\n",
      "train loss:0.11551169053111084\n",
      "train loss:0.17239329228630282\n",
      "train loss:0.09132218352329247\n",
      "train loss:0.17901217742285058\n",
      "train loss:0.15488483050246807\n",
      "train loss:0.15272882614471678\n",
      "train loss:0.1506614095756315\n",
      "train loss:0.09078130512685999\n",
      "train loss:0.15248039335644425\n",
      "train loss:0.11060686623925471\n",
      "train loss:0.16202266277936908\n",
      "train loss:0.13858344998739675\n",
      "train loss:0.1964422651457615\n",
      "train loss:0.14459637998167538\n",
      "train loss:0.12011986269343118\n",
      "train loss:0.11360760306202629\n",
      "train loss:0.19053072525519063\n",
      "train loss:0.12146716714001987\n",
      "train loss:0.1029962241428657\n",
      "train loss:0.19037850122561342\n",
      "train loss:0.19483454014155688\n",
      "train loss:0.26111092192006924\n",
      "train loss:0.1079290836451926\n",
      "train loss:0.1461266275807076\n",
      "train loss:0.10636516588628896\n",
      "train loss:0.14514937759119823\n",
      "train loss:0.09278503961913007\n",
      "train loss:0.09460048517813846\n",
      "train loss:0.17989164246513487\n",
      "train loss:0.059595278756450994\n",
      "train loss:0.10496302605563662\n",
      "train loss:0.05489069857837594\n",
      "train loss:0.07804752409050462\n",
      "train loss:0.11178370898089611\n",
      "train loss:0.15862638111620192\n",
      "train loss:0.15738124088579072\n",
      "train loss:0.13057792940434312\n",
      "train loss:0.08846789984903505\n",
      "train loss:0.2873630134788098\n",
      "train loss:0.113838703926676\n",
      "train loss:0.12812483676490452\n",
      "train loss:0.1872121313786479\n",
      "=== epoch:9, train acc:0.957, test acc:0.933 ===\n",
      "train loss:0.09262585653599384\n",
      "train loss:0.24393678890554052\n",
      "train loss:0.1403953585128687\n",
      "train loss:0.10596533048127028\n",
      "train loss:0.10700116902258724\n",
      "train loss:0.09536134315594508\n",
      "train loss:0.16957569157094451\n",
      "train loss:0.09870310916301049\n",
      "train loss:0.17161748025604798\n",
      "train loss:0.11252784531178449\n",
      "train loss:0.09192449545996846\n",
      "train loss:0.10049217360795691\n",
      "train loss:0.08312411133109411\n",
      "train loss:0.14497507943265694\n",
      "train loss:0.26345266579518145\n",
      "train loss:0.14271637827026637\n",
      "train loss:0.13221254561780754\n",
      "train loss:0.07972540529393135\n",
      "train loss:0.07876332532458952\n",
      "train loss:0.1994780342366004\n",
      "train loss:0.09029032903715699\n",
      "train loss:0.09710149623657172\n",
      "train loss:0.15546252805520652\n",
      "train loss:0.1268663821847025\n",
      "train loss:0.13952207931564886\n",
      "train loss:0.09096080804729356\n",
      "train loss:0.07255490571073003\n",
      "train loss:0.07316993872099693\n",
      "train loss:0.11598624784933749\n",
      "train loss:0.07035683555483913\n",
      "train loss:0.11712751713759577\n",
      "train loss:0.14076161814954566\n",
      "train loss:0.08708990772179108\n",
      "train loss:0.07565238946321681\n",
      "train loss:0.1633377126302255\n",
      "train loss:0.1312648189668677\n",
      "train loss:0.06804830317832608\n",
      "train loss:0.08178336775541474\n",
      "train loss:0.05876692751667945\n",
      "train loss:0.0821168860985159\n",
      "train loss:0.11030814272021948\n",
      "train loss:0.06804882605221037\n",
      "train loss:0.14479014260570655\n",
      "train loss:0.14268905511012409\n",
      "train loss:0.15549891545137728\n",
      "train loss:0.0617383879471912\n",
      "train loss:0.03622707282214508\n",
      "train loss:0.1604206140737243\n",
      "train loss:0.1502559963007728\n",
      "train loss:0.0704546872662581\n",
      "=== epoch:10, train acc:0.958, test acc:0.941 ===\n",
      "train loss:0.07674602780576265\n",
      "train loss:0.06406347743269461\n",
      "train loss:0.052295524198261084\n",
      "train loss:0.130583673914501\n",
      "train loss:0.06038670086699391\n",
      "train loss:0.05947895202615057\n",
      "train loss:0.16863709996843973\n",
      "train loss:0.09373538245858727\n",
      "train loss:0.09275024789252308\n",
      "train loss:0.1374854628396091\n",
      "train loss:0.08998766579007103\n",
      "train loss:0.13501223829864406\n",
      "train loss:0.06351224441267361\n",
      "train loss:0.06352280639280605\n",
      "train loss:0.15784960608988788\n",
      "train loss:0.0639156519758925\n",
      "train loss:0.0929444771018086\n",
      "train loss:0.08722272892301156\n",
      "train loss:0.04160617316922602\n",
      "train loss:0.1230462850465091\n",
      "train loss:0.041902304982636564\n",
      "train loss:0.20064169806961235\n",
      "train loss:0.05323666114756076\n",
      "train loss:0.1259211502966087\n",
      "train loss:0.0458585766629703\n",
      "train loss:0.09257389929844141\n",
      "train loss:0.11102022069873843\n",
      "train loss:0.08681166096487343\n",
      "train loss:0.04903488945941605\n",
      "train loss:0.08055752049636732\n",
      "train loss:0.040961869076320136\n",
      "train loss:0.08361265399518786\n",
      "train loss:0.046760187200936644\n",
      "train loss:0.10679474693543042\n",
      "train loss:0.0876234708133648\n",
      "train loss:0.05810814921355531\n",
      "train loss:0.06757897763652142\n",
      "train loss:0.08073074276568747\n",
      "train loss:0.06519795947887194\n",
      "train loss:0.1165390106022556\n",
      "train loss:0.1332403281125067\n",
      "train loss:0.16734235741083914\n",
      "train loss:0.04657506658232619\n",
      "train loss:0.04503771994959965\n",
      "train loss:0.05064475587153073\n",
      "train loss:0.0667141292194634\n",
      "train loss:0.10536555903014076\n",
      "train loss:0.056531398372858865\n",
      "train loss:0.10478391506150357\n",
      "train loss:0.09242128289708623\n",
      "=== epoch:11, train acc:0.968, test acc:0.946 ===\n",
      "train loss:0.10270646278113177\n",
      "train loss:0.06333296945022254\n",
      "train loss:0.08631396964896171\n",
      "train loss:0.04834857359032127\n",
      "train loss:0.07897490048540327\n",
      "train loss:0.04871023375173193\n",
      "train loss:0.07894893336946412\n",
      "train loss:0.0497511868035226\n",
      "train loss:0.05122325866203825\n",
      "train loss:0.1320685654413176\n",
      "train loss:0.032940934747182675\n",
      "train loss:0.1616113162371393\n",
      "train loss:0.12187485601621603\n",
      "train loss:0.04306503473961405\n",
      "train loss:0.15429976958424546\n",
      "train loss:0.08599067819430228\n",
      "train loss:0.14305122550631844\n",
      "train loss:0.039232695478579764\n",
      "train loss:0.048475839971323216\n",
      "train loss:0.05790641882050221\n",
      "train loss:0.06929423309071522\n",
      "train loss:0.10569556285713878\n",
      "train loss:0.08400278040849395\n",
      "train loss:0.08420982179558935\n",
      "train loss:0.06067901365327593\n",
      "train loss:0.08689839612220952\n",
      "train loss:0.04150632306707762\n",
      "train loss:0.05262425696926345\n",
      "train loss:0.10325416977921599\n",
      "train loss:0.07229824327360468\n",
      "train loss:0.1097283661453384\n",
      "train loss:0.08518450049419866\n",
      "train loss:0.07663624347019359\n",
      "train loss:0.06184733429560332\n",
      "train loss:0.1052703545560216\n",
      "train loss:0.09180069896280824\n",
      "train loss:0.11477300202596044\n",
      "train loss:0.10244128476365727\n",
      "train loss:0.12224084919950334\n",
      "train loss:0.045840641219235895\n",
      "train loss:0.09875863953326619\n",
      "train loss:0.07617884457940098\n",
      "train loss:0.08577214419484658\n",
      "train loss:0.09465071005965797\n",
      "train loss:0.09054756202485664\n",
      "train loss:0.09699051449976254\n",
      "train loss:0.03862659664036287\n",
      "train loss:0.045877863205153586\n",
      "train loss:0.07552831175343629\n",
      "train loss:0.11210499889207598\n",
      "=== epoch:12, train acc:0.967, test acc:0.947 ===\n",
      "train loss:0.09705703592248502\n",
      "train loss:0.06096602097374666\n",
      "train loss:0.08861852677470469\n",
      "train loss:0.06339705396213749\n",
      "train loss:0.05020222222060551\n",
      "train loss:0.09699322253161259\n",
      "train loss:0.04132872032362398\n",
      "train loss:0.11283498399042129\n",
      "train loss:0.11920685270166781\n",
      "train loss:0.06988040223704144\n",
      "train loss:0.06567433934085434\n",
      "train loss:0.09242328866554284\n",
      "train loss:0.056337818802725446\n",
      "train loss:0.059267644628480595\n",
      "train loss:0.11376997522496538\n",
      "train loss:0.06278408320692642\n",
      "train loss:0.05157396835275738\n",
      "train loss:0.05529717729691235\n",
      "train loss:0.05143328374682517\n",
      "train loss:0.03899531704922439\n",
      "train loss:0.04695670885025721\n",
      "train loss:0.1450798841031807\n",
      "train loss:0.10778585303896822\n",
      "train loss:0.11126380673349517\n",
      "train loss:0.02774782875849695\n",
      "train loss:0.056723334736425324\n",
      "train loss:0.05977897760793023\n",
      "train loss:0.056072881475640726\n",
      "train loss:0.11606176484456071\n",
      "train loss:0.059315307378861724\n",
      "train loss:0.060228080547888696\n",
      "train loss:0.09805029577696323\n",
      "train loss:0.07274519288882712\n",
      "train loss:0.2345132199327742\n",
      "train loss:0.07103011399634458\n",
      "train loss:0.07417677682684388\n",
      "train loss:0.059011305539905766\n",
      "train loss:0.12670334530257205\n",
      "train loss:0.05272544768452768\n",
      "train loss:0.10144583267863823\n",
      "train loss:0.02560159910619532\n",
      "train loss:0.12959107005416326\n",
      "train loss:0.06727544732649085\n",
      "train loss:0.09984728532879984\n",
      "train loss:0.06374915535653226\n",
      "train loss:0.13575840403917058\n",
      "train loss:0.18623617650905688\n",
      "train loss:0.09309364817694682\n",
      "train loss:0.10436078613306829\n",
      "train loss:0.06596811206325247\n",
      "=== epoch:13, train acc:0.971, test acc:0.947 ===\n",
      "train loss:0.03857140808314377\n",
      "train loss:0.10035787069403977\n",
      "train loss:0.1280689020376933\n",
      "train loss:0.0756529985150842\n",
      "train loss:0.0925774898363379\n",
      "train loss:0.049688655040586255\n",
      "train loss:0.12082375145119316\n",
      "train loss:0.12111267221647726\n",
      "train loss:0.0928191811221251\n",
      "train loss:0.03697652702009753\n",
      "train loss:0.0685962907558876\n",
      "train loss:0.07064756184303173\n",
      "train loss:0.10034367148945153\n",
      "train loss:0.06893083786692084\n",
      "train loss:0.0939730130071872\n",
      "train loss:0.06977483919597617\n",
      "train loss:0.03180708395281301\n",
      "train loss:0.06727352734297587\n",
      "train loss:0.07407076591751449\n",
      "train loss:0.04194988887972091\n",
      "train loss:0.0754439487440803\n",
      "train loss:0.05557345882384168\n",
      "train loss:0.05719087150682471\n",
      "train loss:0.09936613782287623\n",
      "train loss:0.02406834387313046\n",
      "train loss:0.05126913713027448\n",
      "train loss:0.07676019137181286\n",
      "train loss:0.06516161732196463\n",
      "train loss:0.034761922549907814\n",
      "train loss:0.047735485875838556\n",
      "train loss:0.08253982515753146\n",
      "train loss:0.018298544261497948\n",
      "train loss:0.09289037892914985\n",
      "train loss:0.06248814214783081\n",
      "train loss:0.05193666782250061\n",
      "train loss:0.04820854215552318\n",
      "train loss:0.10321679522608009\n",
      "train loss:0.03342246329709425\n",
      "train loss:0.088800865753506\n",
      "train loss:0.04385580327096568\n",
      "train loss:0.07617207285600301\n",
      "train loss:0.04111996889198226\n",
      "train loss:0.02325171740227459\n",
      "train loss:0.06251521395102937\n",
      "train loss:0.06164980204240551\n",
      "train loss:0.09311132816347248\n",
      "train loss:0.03202810699409368\n",
      "train loss:0.02860716560424768\n",
      "train loss:0.14151720387896113\n",
      "train loss:0.04340749170000634\n",
      "=== epoch:14, train acc:0.967, test acc:0.947 ===\n",
      "train loss:0.10699823291795994\n",
      "train loss:0.07391523972639524\n",
      "train loss:0.04848983169855843\n",
      "train loss:0.08928916331539831\n",
      "train loss:0.07522325629019579\n",
      "train loss:0.028712687259687505\n",
      "train loss:0.049968386166694516\n",
      "train loss:0.10496405781152945\n",
      "train loss:0.055080662659593535\n",
      "train loss:0.03689247805247757\n",
      "train loss:0.11655582848151941\n",
      "train loss:0.026411151398114535\n",
      "train loss:0.05124294922602929\n",
      "train loss:0.0562666530730222\n",
      "train loss:0.02621854455167717\n",
      "train loss:0.05285213964490529\n",
      "train loss:0.0787683138297015\n",
      "train loss:0.06244208121642417\n",
      "train loss:0.06604489074309611\n",
      "train loss:0.04824875943257644\n",
      "train loss:0.02691466040748458\n",
      "train loss:0.0661235831346452\n",
      "train loss:0.029522317067786887\n",
      "train loss:0.03212867111519983\n",
      "train loss:0.05096606022340171\n",
      "train loss:0.059717645028647945\n",
      "train loss:0.05297128645918045\n",
      "train loss:0.06987865406309966\n",
      "train loss:0.09026939611009117\n",
      "train loss:0.042973755503326154\n",
      "train loss:0.014488810534974826\n",
      "train loss:0.05348146993061722\n",
      "train loss:0.03871051561621999\n",
      "train loss:0.046552632053711455\n",
      "train loss:0.026376660310846477\n",
      "train loss:0.05150414668260853\n",
      "train loss:0.04078866945508176\n",
      "train loss:0.04154463886513203\n",
      "train loss:0.04468058488729895\n",
      "train loss:0.028991134521655134\n",
      "train loss:0.047766888854671466\n",
      "train loss:0.027668509911395637\n",
      "train loss:0.03976135779762877\n",
      "train loss:0.06344700053264252\n",
      "train loss:0.03542421116525166\n",
      "train loss:0.09733311954028456\n",
      "train loss:0.022579319638719033\n",
      "train loss:0.043737056681809536\n",
      "train loss:0.039254593639370505\n",
      "train loss:0.012534691331918075\n",
      "=== epoch:15, train acc:0.983, test acc:0.95 ===\n",
      "train loss:0.02724782497737952\n",
      "train loss:0.02595711663197191\n",
      "train loss:0.021740983245537405\n",
      "train loss:0.07948284069343013\n",
      "train loss:0.04243865404665272\n",
      "train loss:0.035529821368578685\n",
      "train loss:0.02771250814933254\n",
      "train loss:0.03645768886075242\n",
      "train loss:0.0720625965042219\n",
      "train loss:0.024844577560810707\n",
      "train loss:0.04951236587900798\n",
      "train loss:0.0879676682017012\n",
      "train loss:0.018550504523872947\n",
      "train loss:0.06567791639733404\n",
      "train loss:0.03873357846044188\n",
      "train loss:0.03193352002556308\n",
      "train loss:0.03455608876170444\n",
      "train loss:0.052171972745491464\n",
      "train loss:0.055841200889453074\n",
      "train loss:0.04758649583963475\n",
      "train loss:0.03291663464705472\n",
      "train loss:0.023731749303382524\n",
      "train loss:0.04482104275280598\n",
      "train loss:0.02031731310740652\n",
      "train loss:0.05800148555388155\n",
      "train loss:0.04526192781370507\n",
      "train loss:0.06665547984666231\n",
      "train loss:0.06342049224844672\n",
      "train loss:0.037689712322056736\n",
      "train loss:0.04460393047400913\n",
      "train loss:0.05175926274185081\n",
      "train loss:0.011206050774355618\n",
      "train loss:0.022603938321515547\n",
      "train loss:0.0378289960496883\n",
      "train loss:0.02607000646753783\n",
      "train loss:0.06618268499584147\n",
      "train loss:0.027152977684832\n",
      "train loss:0.013572803850630168\n",
      "train loss:0.04956340877218463\n",
      "train loss:0.04545615090306847\n",
      "train loss:0.05064058756660203\n",
      "train loss:0.05432878478339337\n",
      "train loss:0.1628593944281173\n",
      "train loss:0.02059564548610127\n",
      "train loss:0.041292381721001255\n",
      "train loss:0.028500523834455338\n",
      "train loss:0.06771930432585889\n",
      "train loss:0.06891779372384221\n",
      "train loss:0.03462476845878266\n",
      "train loss:0.036923822941757306\n",
      "=== epoch:16, train acc:0.987, test acc:0.957 ===\n",
      "train loss:0.03851666621705622\n",
      "train loss:0.030354035384732695\n",
      "train loss:0.04195262821736051\n",
      "train loss:0.03734548811557815\n",
      "train loss:0.03566189688920514\n",
      "train loss:0.019246347478969642\n",
      "train loss:0.10196118330474159\n",
      "train loss:0.024234146568756706\n",
      "train loss:0.0315879917491553\n",
      "train loss:0.022470209733007836\n",
      "train loss:0.022516676988043577\n",
      "train loss:0.052368979180164456\n",
      "train loss:0.04089416538955617\n",
      "train loss:0.026448215454059874\n",
      "train loss:0.025024998989297897\n",
      "train loss:0.08985399387302387\n",
      "train loss:0.029771826246261825\n",
      "train loss:0.028674663172333227\n",
      "train loss:0.016152071980694807\n",
      "train loss:0.06316499122077972\n",
      "train loss:0.0308050140650799\n",
      "train loss:0.03793646315082042\n",
      "train loss:0.029581341724888698\n",
      "train loss:0.04451282679539034\n",
      "train loss:0.07447995162779404\n",
      "train loss:0.029847914838380663\n",
      "train loss:0.0679550315685362\n",
      "train loss:0.05446153897274328\n",
      "train loss:0.02745221876853367\n",
      "train loss:0.021540641588644934\n",
      "train loss:0.023845093017924323\n",
      "train loss:0.04212950075977025\n",
      "train loss:0.028598060993991436\n",
      "train loss:0.029174291909252106\n",
      "train loss:0.015597931774763383\n",
      "train loss:0.017938904405998234\n",
      "train loss:0.011104195267919868\n",
      "train loss:0.012130448280188966\n",
      "train loss:0.012217098028623957\n",
      "train loss:0.025684524270132555\n",
      "train loss:0.03200050478814054\n",
      "train loss:0.0372299858905947\n",
      "train loss:0.09293697217092928\n",
      "train loss:0.04478911930053044\n",
      "train loss:0.026254069215432437\n",
      "train loss:0.05745774418668122\n",
      "train loss:0.04093576737164133\n",
      "train loss:0.02484027163712253\n",
      "train loss:0.021983925101655425\n",
      "train loss:0.02450090607114741\n",
      "=== epoch:17, train acc:0.986, test acc:0.954 ===\n",
      "train loss:0.038469886513751474\n",
      "train loss:0.026947323563429104\n",
      "train loss:0.023946581920229665\n",
      "train loss:0.035706221849049126\n",
      "train loss:0.06444768420255918\n",
      "train loss:0.028902912114101523\n",
      "train loss:0.02025469385497962\n",
      "train loss:0.027456991945708324\n",
      "train loss:0.03690695173168007\n",
      "train loss:0.06352321067389702\n",
      "train loss:0.03048826744879995\n",
      "train loss:0.012588250008721038\n",
      "train loss:0.08951490784736599\n",
      "train loss:0.04505364203766573\n",
      "train loss:0.05495830335672782\n",
      "train loss:0.08623250116499\n",
      "train loss:0.0761792409340149\n",
      "train loss:0.06526427462499264\n",
      "train loss:0.021025869268456853\n",
      "train loss:0.010925535383748622\n",
      "train loss:0.025741360716981368\n",
      "train loss:0.01666347729842807\n",
      "train loss:0.019013900442464704\n",
      "train loss:0.030017777580831514\n",
      "train loss:0.02499493528293716\n",
      "train loss:0.02123944704932024\n",
      "train loss:0.02917619440256176\n",
      "train loss:0.05046963777678011\n",
      "train loss:0.023767702903743592\n",
      "train loss:0.01930616722885844\n",
      "train loss:0.020220304121534013\n",
      "train loss:0.017224824368681113\n",
      "train loss:0.050328277454163815\n",
      "train loss:0.010579414981759024\n",
      "train loss:0.028600555437103353\n",
      "train loss:0.02271385220533317\n",
      "train loss:0.055018674244877296\n",
      "train loss:0.03878332181300229\n",
      "train loss:0.07570327185577704\n",
      "train loss:0.026952549143128713\n",
      "train loss:0.04670771182196485\n",
      "train loss:0.019840731210854706\n",
      "train loss:0.011180311119825809\n",
      "train loss:0.019096598879504458\n",
      "train loss:0.0060614051755191346\n",
      "train loss:0.0278583053610939\n",
      "train loss:0.02783004789697554\n",
      "train loss:0.042185892126287976\n",
      "train loss:0.015037078360276708\n",
      "train loss:0.03145837717281364\n",
      "=== epoch:18, train acc:0.991, test acc:0.953 ===\n",
      "train loss:0.007055503612573016\n",
      "train loss:0.0575967402735377\n",
      "train loss:0.020832795988154525\n",
      "train loss:0.0332531756550679\n",
      "train loss:0.011308148004865223\n",
      "train loss:0.026725641914697548\n",
      "train loss:0.017234983723659912\n",
      "train loss:0.014154663451222573\n",
      "train loss:0.018129329900560164\n",
      "train loss:0.031786786008893474\n",
      "train loss:0.03598835564324696\n",
      "train loss:0.027902777946281803\n",
      "train loss:0.019507201874918688\n",
      "train loss:0.02579716297143645\n",
      "train loss:0.017633842804526834\n",
      "train loss:0.011803575065942344\n",
      "train loss:0.034164947776048005\n",
      "train loss:0.046311961910584704\n",
      "train loss:0.010942257068581625\n",
      "train loss:0.07686322338614969\n",
      "train loss:0.025885744689975468\n",
      "train loss:0.023137917256400974\n",
      "train loss:0.012250856286495772\n",
      "train loss:0.03789657610608354\n",
      "train loss:0.06507621492014731\n",
      "train loss:0.019421930834040032\n",
      "train loss:0.02412653933522917\n",
      "train loss:0.004453385982951442\n",
      "train loss:0.029814896644300162\n",
      "train loss:0.038755987748277446\n",
      "train loss:0.030359182732243726\n",
      "train loss:0.02878918006107765\n",
      "train loss:0.04849416861576452\n",
      "train loss:0.007479651349468377\n",
      "train loss:0.029743023896709866\n",
      "train loss:0.01811591305076979\n",
      "train loss:0.01723783506465416\n",
      "train loss:0.031228360776858137\n",
      "train loss:0.015857324954172157\n",
      "train loss:0.00785351255242535\n",
      "train loss:0.013765821679596012\n",
      "train loss:0.012850093606288006\n",
      "train loss:0.010903062873093275\n",
      "train loss:0.012768902022311577\n",
      "train loss:0.024461608296943108\n",
      "train loss:0.035585848910396074\n",
      "train loss:0.014827394028611161\n",
      "train loss:0.012986027832335942\n",
      "train loss:0.014977429150448198\n",
      "train loss:0.011398802990776202\n",
      "=== epoch:19, train acc:0.993, test acc:0.961 ===\n",
      "train loss:0.02037853353438674\n",
      "train loss:0.05887582267845233\n",
      "train loss:0.027650431570313976\n",
      "train loss:0.008240615542438564\n",
      "train loss:0.019772213923975216\n",
      "train loss:0.029725682016981763\n",
      "train loss:0.02303804563415549\n",
      "train loss:0.022929657383424756\n",
      "train loss:0.008086101658525788\n",
      "train loss:0.02404112109897885\n",
      "train loss:0.03909045222342272\n",
      "train loss:0.019426616343190362\n",
      "train loss:0.013148517994109193\n",
      "train loss:0.016377350850925308\n",
      "train loss:0.018658649936472448\n",
      "train loss:0.01735655745401438\n",
      "train loss:0.05225969096311259\n",
      "train loss:0.005041657439613959\n",
      "train loss:0.02454651320529734\n",
      "train loss:0.006826816702964104\n",
      "train loss:0.013123041323296464\n",
      "train loss:0.015512151747766234\n",
      "train loss:0.03047458125837916\n",
      "train loss:0.021828277092834483\n",
      "train loss:0.019495336619439956\n",
      "train loss:0.014173918386243703\n",
      "train loss:0.009087526195842566\n",
      "train loss:0.008915524551149238\n",
      "train loss:0.014209928092624401\n",
      "train loss:0.0048455536988289485\n",
      "train loss:0.02021260402045774\n",
      "train loss:0.04537355438629411\n",
      "train loss:0.008728372192389852\n",
      "train loss:0.06481290657789698\n",
      "train loss:0.006977380774607079\n",
      "train loss:0.05932607874044543\n",
      "train loss:0.015002855778140827\n",
      "train loss:0.020372910824105434\n",
      "train loss:0.0205853454060437\n",
      "train loss:0.010850833228301318\n",
      "train loss:0.01617876673568966\n",
      "train loss:0.02495842431049969\n",
      "train loss:0.028779565802716532\n",
      "train loss:0.04167272047513564\n",
      "train loss:0.014734078298557931\n",
      "train loss:0.01266517490496413\n",
      "train loss:0.04002108242176097\n",
      "train loss:0.01903000296490269\n",
      "train loss:0.007320237806401446\n",
      "train loss:0.026922099670122112\n",
      "=== epoch:20, train acc:0.995, test acc:0.96 ===\n",
      "train loss:0.016650012220241896\n",
      "train loss:0.02457324952384454\n",
      "train loss:0.021337637408491764\n",
      "train loss:0.019776752991454802\n",
      "train loss:0.01567744009806673\n",
      "train loss:0.021126751841366084\n",
      "train loss:0.03286535087590937\n",
      "train loss:0.03467294368406161\n",
      "train loss:0.010363729480107265\n",
      "train loss:0.01628827966312284\n",
      "train loss:0.006254229031472796\n",
      "train loss:0.012673930225205337\n",
      "train loss:0.010720433624416232\n",
      "train loss:0.02060447419750029\n",
      "train loss:0.025822782214438535\n",
      "train loss:0.00466106316212956\n",
      "train loss:0.008588612774003732\n",
      "train loss:0.0362645894936636\n",
      "train loss:0.014382028965577767\n",
      "train loss:0.008121749496583362\n",
      "train loss:0.017489537544988554\n",
      "train loss:0.020809976500420966\n",
      "train loss:0.0177040898656158\n",
      "train loss:0.011593306562900656\n",
      "train loss:0.022365068085595304\n",
      "train loss:0.03400318042492067\n",
      "train loss:0.013771726384246208\n",
      "train loss:0.011109861829391434\n",
      "train loss:0.006345279137622033\n",
      "train loss:0.014419124963825714\n",
      "train loss:0.028065796892735063\n",
      "train loss:0.027001895309513694\n",
      "train loss:0.025113681617829123\n",
      "train loss:0.0405817259220381\n",
      "train loss:0.028161277248974576\n",
      "train loss:0.024785085138176913\n",
      "train loss:0.027440385939424883\n",
      "train loss:0.0098783355276952\n",
      "train loss:0.012605902917422786\n",
      "train loss:0.0074400698135640006\n",
      "train loss:0.01902857100720454\n",
      "train loss:0.02100024096361091\n",
      "train loss:0.006503188236478636\n",
      "train loss:0.014055708624589469\n",
      "train loss:0.016215027951032142\n",
      "train loss:0.023696475038592048\n",
      "train loss:0.006621085970228625\n",
      "train loss:0.007813696634001301\n",
      "train loss:0.01992021678440284\n",
      "=============== Final Test Accuracy ===============\n",
      "test acc:0.953\n",
      "Saved Network Parameters!\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de3xcdZ3/8ddnksn92qQtbdpCqQUpWFqoKCIqokIRpbisq6jrsquVFVZdF37Auov4c31s/aGuD10E+fnDOyJyFyogguCKCAVK79CCtE3SJmnS3K8z8/39cSbtNJmZTC4nk855Px+Pecy5zvnkdHo+c77nezHnHCIiElyhbAcgIiLZpUQgIhJwSgQiIgGnRCAiEnBKBCIiAadEICIScL4lAjO7zcyazWxLivVmZt8xs11mtsnMTvMrFhERSc3PO4IfAeenWb8aWBp/rQVu9jEWERFJwbdE4Jx7CmhLs8lFwE+c5xmgyszm+RWPiIgkl5/FY9cBexPm6+PL9o3c0MzW4t01UFpaevob3/jGaQlQRHJDe+8Q+zv7GYrGCOeFOKaiiKqS8LQdu6G9j1hCLw4hM+qqiqkqCeMcOBwxB845b95BjOFpd2i+IC+PovDEfr8///zzB5xzs5Oty2YisCTLkvZ34Zy7FbgVYNWqVW7Dhg1+xiUiU+y+Fxu48ZGXaWzvY35VMVefdyJrVtZN27Gvu2cztUPRQ8vC4Tz+7UNvShrDUDRGZ98Q7X1DdAy/er33/qEog5EYg9EYg5EYAwnTh14J8wPRGO0NHcyNjb60RYAD4/xb/v6dS7h29cR+CJvZ7lTrspkI6oGFCfMLgMYsxSIiPhm+EPfFL8QN7X1cd89mANasrCMSjdEzEKVrYIjugQjd/RHvPWG6qz9Cz0AEB+SFzHuZHZ5Os+zrD+84dOxhfUNR/vXezTy8ZT/tfYN09EW8i3/vID2D0ZF/wihmUJAXoiA/RGF+6NB0QX6I8PB0XoiKonwiSZLAsM+du3TU/kdM54coTJifW1E08X+INLKZCB4ArjSzO4C3AB3OuVHFQiJy9OobjPIfD21LeiH+4p0bj0gQYykpyCNkRiQWIxbDe59En5m9g1FebemmsjhMXVURJ80rp6q4gMriMJXF+VSVxKdLwvFlYYrDeRTkh8gPGWbJCjVGO2vd4zS0941aXldVzBffe8LE/4Ap5FsiMLNfAO8Cas2sHvgyEAZwzt0CrAcuAHYBvcBlfsUikm3ZLBqZruPHYo6/tPbw4p52Nu49yIt72tmxv4toiqt1zMEnzjyWssL8w6+iw+/lCfOlBfmEQqMvvM45ojFHdPh95Ms51tz0R5o6B0btW1dVzG+/+M4pPQfJXH3eiaMSXnE4j6vPO9H3Y2fKt0TgnPvoGOsdcIVfxxeZKcYqGpmO4197zyb6h2JTevz23kFe3NvOxj3t8feDdPZHACgvzOfUhVV89l1LuP3Pe2jtGRy1f11VMf96wUkTPj6AmZGfZ2kvZNetPimrF+Lhc5zNHwJjsaNtPAI9LJajyUAkytlff4LmrtG/SKuKw/zbhcvIDxmhkHnv5r2PKvsOGSGDvsEY3QNDdMXLznsGInQllqv3HznfMxBJehEGyDPjDXPKRv8KT/xlnjBdmJ/HruYuXoxf+P9yoAeAkMEJc8tZuaiKlQurWbmoiiWzyw79gh+ZCMG7EP9nioe1fsjqHdmNS6GnefTy0jlw9c7piQEws+edc6uSrlMikCDw80LQ3jvI7tZedrf1sretl92tPexu7WVPWy/7O/vx+79YYX6I8qJ8ShOKWMoTLu4/e2ZPyn3PO3nuqATSMxBJ+8C0tqzQu+jHL/zLF1RSWpi+cCHQF+IbKtOs6/D/+HHpEkE2HxaLTItkRTPX3rOJzr5B3nvyMUSijphzRGKOWMx7TyxjHp4ejMRobO9jd1sve1p72d3Ww57W3kPFIcNm
lxeyaFYJZx5fw6KaEn789Osc7B0aFdfcikJ+9Zm3xR96jjhuihhKCo680JcW5lOQn75e+RM7WlI+rPz+J5JeF4jGHD2DR9ba6R2MsLi2lLqq4owflAJw41LW9DSzBqAI6AfuBx4b54U4MgihPO81HsmSQLrlExUZgK590Nl45Cud754Olnf477I8COXH5/PBQkfOL7sIVlw6tXGjRCA5rqN3iBse2DqqZkr/UIzrH9jG9Q9sG/dnhvOMBdUlLJxVwsqF1Rxb400fW1PColkllBQc+d/q8ufOp6ioddTn9FsNRTWvjfv44/U796nkx3c1QPLj54WMiqIwFUVT0Ogq3YXYOeg7CF37obsJupuhe3/8vSm+PD7d3x4PrhDCxVBQ6r2HiyFcOmJZSfxVnD62bfcnXIiTXHiPuEiHoLcVOvdBZ0PCxb7BSwA9LaM/v6As/fHnnQqxCMSi4GKHp2ORw/ORgfh8FPo7xz7fE6BEINNiOosGOvuHeGxbEw9u2scfdrYwFE1dNrPuQ286VD4/Vv30/LwQ8yqLmF9VTF6SGiypFA2MvginWz7VfDm+cxAdhKFeGOqDwd74dO/oZel8dTbERt8tES6Bsrnea/aJcPw7oXS2d9zE4wzGjzXUA4Pd3sV4sCe+rNebTufOv534OSieBRXzvVfdaVBR502Xzzs8XVSRvmjoktsmfvwppEQgvpuOWjM9AxEe2+5d/J98pYXBSIy6qmIuO2sx977YQEuSh7V1VcV85IxFU3J8hvqgr937ddsff+9rP/wrNpWffsj7pReLJvwSjL/HYiPmM6tvPy7fXj6OjZ3363T4Au+mIJ63XQllx0DZHO+iXx6fLijzWm1NhXQX4n98OuFXeDTh3yKS/N8l8eI/1t3GUUSJQHyXqmXnuod3cNGK+eMrb078jMEoj+9o5qHNjTy+o5n+oRhzKwr5+FuO5f3L53HaoirMjC++9P5xF40QGfRu97sSiwH2QV/biAt+fDo6OtFkpL/9yCKI/IKE+XwIhY5cn7RnljG0p+xZABadOb7PChcdWewSLoGCkvTL/mtZ6s97zw3jO/5Um3uy/8conZP6YfUMoUQgUy4SjfFSfTt/2HmAP+w8wL6O/qTb7e/o56TrH6auqpi66hLqqopZUO29vGXFzCkvOqIYpn8oypOvtPDgpn38bnsTvYNRassK+fCqhVy4fD6rjq0e1fAobdHIS3cceaEfnk72HzdcAiU1UFQFxVVQuzQ+Xe3NF1cfXndouhq+fmzqk/Xpx8c+oZO16Y7U6z70ff+Pn23ZvhBPYxXRiVIikElzzrG7tZc/7DrAH15p4U+vttI1EMEMltdVUlaYT/dAZNR+lcX5/PXpC6k/2EdDex9bGjpoG1HnPZxnzKv0EkN5UT5Pv9pK90CE6pIwF62o4wPL5/GW42uSl9k7Bwf/kj74ez/jvRdVHS7XnXfq4dv/ivne8vJ5UFQ5dcUVQaIL8YynRCAT0tE7xNOvHuCpnQf4n10t7G3zqifWVRVz4anzePsbZvO2JTVUlxbQ/5/HU2RJimYKaii68Miimd7BCA0H+6hv76MhniDqD/bRcLCXxqY+Vp9yDBeeOp+3LakhnDei2mRkEPZvgj3PwN5nYM+fx64i+E8veBf5gpJJnY+0sn0hzPbxdSGe8ZQIAmCyNXb6h6Ic6B5gb1sfT7/qFfdsqm8n5qCsMJ8zl9Tw6bOP5+ylszmupmRUmf94aq2UFOSzdG45S+eWjx1Y30HY+2z8wv9naHgBIvH68lXHwpJzYOFb4KEvpv6MmiVjH2eysn0hzPbxZcZTIshx/f95PGsGWkc15ul/uIbez+2gpWuAlq4Bmrv6E6a995buAZo7+49oMBUyWLGwiivfvZR3LK3l1IVVo3+ZD4vFvDL3dJ65xXtAmlcI+YWQVzD6fXjaxaBx4+Ff+y3b40HlwzHLYdVl3oV/0Vu92ifD0iUCEVEiyGWxmEv7a/yNX/3t6OXhEHPKi5hT
XsjSOWW8bUkNc8oLmV1eyNyKIlYuqqayeEQjo942aN115OvALmh7FSLJHxQf8vA14//DCith4ZvhlL+CRW+ButO9hkSpZLtoRGSGUyLIEYORGDubu9ja2Mm2+Gv7vk42p3m2+cM376W8pIiK4iLKS4uoKi2mqLAAC+V5o1lbxHsPDUGoH2Jt8Jc/Hb7QD1/0+xKGpg7lQ/VxUPMGr2im5g3w4BdSB/G//uLVTY8OeGX8R7wPeI2Wht9dzKvuN/skr1plplQ0IpKWEsFRqKt/iO37utjW2MHWxk62Nnays7nrUAva4nAeJ80r52+WV8Lm1J9zzuYJ/BofVj7Pu8gvu8h7r3mDV52yahHkjbhjSJcISmZNPAYRmRJKBEeBrv4hHtvexGPbm9nS0MHu1sPN9mtKC1g2v4J3nHA8y+ZXcMqcAo5t+yN5W34O2x5J/8GffSZJq8rIiOmE1q3gXehnLYHCMfpQSaSiGZEZTYlghuodjPDY9mYefKmR38e7TJhbUchpi6q55LQFnFxXwcnzK5lTXoi5GLz+B9j8XVj/axjo8PplOf3v4Nk0DYbmTG5QkIypaEZkRlMimEH6h6I8saPZazW7o4n+oRhzygu59IxFfODUeaxcmNBq1jlofBGevgu23O312FhQDid9AN50CSx+J+Tlw9Z79WtcRNJSIsiygUiUJ1/2ukx4LN5lQk1pAZecvoALl8/nzcfNOrLVbOursPlX3qt1F4TCsPR9sPyv4YTzR3eEpV/jIjIGJYJpMLJB1z+/Zyk1ZYX8elMjv93aRNdAhKqSMBetmM/73zSftx4/i/zEuvm9bd6v/o23Q+MLgMFxb4e3fQ6WfdDrz0ZEZIKUCHyWrAvmq+7aBEB5UT7nnXIMFy6fx1lvqD2yYVYsCq89AS/+HHY86FWfnPsmeN9/wMkfgsqZM/C1iBzdlAh8duMjL4/qghlgVmkBf7ru3RTmjxh2r/VV75f/S7/wWuUWV8Ppl8HKj3mdoYmITDElAh8NRWNJx4oFONgzeDgJDHR7Q+a9+DPY87Q3JN6Sc+G8r8GJF3jdK4iI+ESJwCeb6tu55u7NPFf4j8y2jlHrW6mC3b/win623usNtTdrCZx7PZz6Ua/7YxGRaaBEMMV6ByN889FX+OEf/0JtWWHSJABQQzv8cLU3JN8pF8PKT3gdpqm/exGZZkoEU+jJV1r40r2bqT/Yx8fesohrVr8R1qXZYc3NcNIHx9dKV0RkiikRTIHW7gG++uA27tvYyJLZpdz5mTM5Y/Esr9FXOisunZ4ARUTSUCKYBOcc977YwFcf3Eb3QITPnbuUK85ZQmFeCHY8BE/dmO0QRUTGpEQwQXvbevnXezfzh50HOG1RFev+ajknzC6F7ffDU9+Api1ed8wiIjOcEsE4RaIxfvjH1/nWb18hL2R89aKT+dib6whtuxd+9Q048DLULIWLvw+nXALfOkl9/YjIjKZEMA5bGjq47p7NbG7o4D0nzeWrHziBebsfgO99E9pegznL4JLbYNkaCMXbCKivHxGZ4ZQIMrStsZOLbvojs0oLuOUjp3De0O+wn1wG7Xu88XL/5mdw4vvHN3KWiMgMoESQoZfq28mPDfDwma9S8/jnvO4f6k6HC77h9f6p+v8icpRSIshQT/PrPFn4z9Q8dRAWnQkf/C4sebcSgIgc9ZQIMlS1/48cYwfhI7d7/f8oAYhIjvC1QNvMzjezl81sl5ldm2R9pZn92sxeMrOtZnaZn/FMRl73fm9iyblKAiKSU3xLBGaWB9wErAaWAR81s2UjNrsC2OacOxV4F/BNMyvwK6bJKOzdT1eoEsJF2Q5FRGRK+XlHcAawyzn3mnNuELgDuGjENg4oNzMDyoA2IOJjTBNWNtRCV8HsbIchIjLl/EwEdcDehPn6+LJE/w2cBDQCm4HPO+diIz/IzNaa2QYz29DS0uJXvCkNRWPURA/QX3zMtB9bRMRvfiaCZAXpI3thOw/YCMwHVgD/bWYV
o3Zy7lbn3Crn3KrZs6f/V3lz1wBz7SCxMiUCEck9fiaCemBhwvwCvF/+iS4D7nGeXcBfgDf6GNOENB/soNY6yavUYDEiknv8TATPAUvNbHH8AfBHgAdGbLMHOBfAzOYCJwKv+RjThHQ2eyVcBbMWjrGliMjRx7d2BM65iJldCTwC5AG3Oee2mtnl8fW3AF8FfmRmm/GKkq5xzh3wK6aJ6mnxEkH5bCUCEck9vjYoc86tB9aPWHZLwnQj8D4/Y5gKQ+0NAJTNXpTlSEREpp56SMtEl/doI1Q5stKTiMjRT4kgA+Ge/fRTCEWV2Q5FRGTKKRFkoKS/mY7wbHUtISI5SYkgA5WRFnoLNaKYiOQmJYIx9A5GmO1aGSxVYzIRyU1KBGNo6uhjDgehfF62QxER8YUSwRjamhsosCj51QuyHYqIiC+UCMbQFW9MVlKjxmQikpuUCMbQ3+Ylgoo5akwmIrlJiWAMsXavMVlJre4IRCQ3KRGMwbr3ESWElc3NdigiIr5QIhhDUV8T7aFZEMrLdigiIr5QIhhD2UCzhqgUkZymRJCGc47q6AH6i1UsJCK5S4kgjY6+IebQRrRMjclEJHcpEaTRfKCVCutT99MiktOUCNJob9oNQGG1EoGI5C4lgjR6W73GZBqZTERymRJBGpG2egAqjzk2y5GIiPhHiSANFx+islAdzolIDlMiSCO/Zz9dVgbh4myHIiLiGyWCNEr6m2jPV2MyEcltSgRpVAwdoEdDVIpIjlMiSCEac9TEWhkqUatiEcltSgQpHOjoppYOXPn8bIciIuIrJYIU2pr2EDJHfpUak4lIblMiSKGr2WtMVlyrqqMiktuUCFLoOzREpRqTiUhuUyJIIdbeAEDV3OOyG4iIiM+UCFIIde9jgDB5pbOyHYqIiK+UCFIo6N1PW6gWzLIdioiIr5QIUigbaKFTQ1SKSAAoEaRQFT1Af5FaFYtI7lMiSKJ/MMIcpyEqRSQYfE0EZna+mb1sZrvM7NoU27zLzDaa2VYze9LPeDLV2rKPQhvCKtSqWERyX75fH2xmecBNwHuBeuA5M3vAObctYZsq4HvA+c65PWY2I8pi2pt2UwcUzFJjMhHJfX7eEZwB7HLOveacGwTuAC4asc2lwD3OuT0AzrlmH+PJWE/LHkBDVIpIMPiZCOqAvQnz9fFliU4Aqs3s92b2vJn9bbIPMrO1ZrbBzDa0tLT4FO5hgweHG5OpVbGI5D4/E0GyCvhuxHw+cDrwfuA84N/N7IRROzl3q3NulXNu1ezZ01Cls3MfMWdU1KrDORHJfRklAjO728zeb2bjSRz1wMKE+QVAY5JtHnbO9TjnDgBPAaeO4xi+yO/ZR1uoCssvyHYoIiK+y/TCfjNeef5OM1tnZm/MYJ/ngKVmttjMCoCPAA+M2OZ+4GwzyzezEuAtwPYMY/JNUX8T7Xm12Q5DRGRaZJQInHOPOec+BpwGvA781syeNrPLzCycYp8IcCXwCN7F/U7n3FYzu9zMLo9vsx14GNgEPAv8wDm3ZbJ/1GRpiEoRCZKMq4+aWQ3wceATwIvAz4G3A58E3pVsH+fcemD9iGW3jJi/EbhxPEH7yTlHTfQArSWrsh2KiMi0yCgRmNk9wBuBnwIfcM7ti6/6pZlt8Cu4bOju7qTSenDlalUsIsGQ6R3BfzvnHk+2wjmXUz+d2/btphw0RKWIBEamD4tPircCBsDMqs3ssz7FlFVdzbsBKKpZOMaWIiK5IdNE8GnnXPvwjHPuIPBpf0LKrr7WegAq5qhVsYgEQ6aJIGR2eISWeD9COVnJPtLhNXWYNU+tikUkGDJ9RvAIcKeZ3YLXOvhyvGqfOSfU1UgnJVSUVY29sYhIDsg0EVwDfAb4R7yuIx4FfuBXUNnkDVFZQ0W2AxERmSYZJQLnXAyvdfHN/oaTfaUDzXSFNUSliARHpn0NLTWzu8xsm5m9NvzyO7hsqI4coLd4brbDEBGZNpk+LP4h3t1ABDgH+Ale
47KcEosMMcsdJFp6TLZDERGZNpkmgmLn3O8Ac87tds7dALzbv7Cyo72lgTxzWLmGqBSR4Mj0YXF/vAvqnWZ2JdAA5FyvbO1NrzMLDVEpIsGS6R3BF4AS4HN4A8l8HK+zuZzS0+INqFYyW62KRSQ4xrwjiDce+7Bz7mqgG7jM96iyZOCg16pYQ1SKSJCMeUfgnIsCpye2LM5ZHY0MuHxq5+gZgYgER6bPCF4E7jezXwE9wwudc/f4ElWW5PXs54DNoi4/42EaRESOeple8WYBrRxZU8gBOZUIhoeoVAfUIhIkmbYsztnnAokqBpupLzox22GIiEyrTEco+yHeHcARnHN/P+URZYtz1MRaea1ErYpFJFgyLRp6MGG6CLgYaJz6cLJnqOcgRQwS0xCVIhIwmRYN3Z04b2a/AB7zJaIsObj/deYAeRqiUkQCJtMGZSMtBXJqCK/O+BCVxWpVLCIBk+kzgi6OfEawH2+MgpwxPERluYaoFJGAybRoqNzvQLIt2t4AwKxj1KpYRIIl0/EILjazyoT5KjNb419Y08+6GjngKphVXpbtUEREplWmzwi+7JzrGJ5xzrUDX/YnpOwI9zbRGqolFMr9njRERBJlmgiSbZdT/TCUDjTToSEqRSSAMk0EG8zsW2a2xMyON7P/Ap73M7DpVhVpoa8o54ZYEBEZU6aJ4J+AQeCXwJ1AH3CFX0FNu6F+Kl2XhqgUkUDKtNZQD3Ctz7FkTV/bXooBV6Hup0UkeDKtNfRbM6tKmK82s0f8C2t6te/fA0BBtRqTiUjwZFo0VBuvKQSAc+4gOTRmcfcBLxGU1mqIShEJnkwTQczMDjW5NbPjSNIb6dFqsM0bq7hSjclEJIAyrQL6JeB/zOzJ+Pw7gLX+hDT9Yh376HZFzK5V9VERCZ6M7giccw8Dq4CX8WoO/QtezaGckNezj2ZmUV6YU00jREQykunD4k8Bv8NLAP8C/BS4IYP9zjezl81sl5mlrHVkZm82s6iZXZJZ2FOruK+Jg/m1mKlVsYgET6bPCD4PvBnY7Zw7B1gJtKTbwczygJuA1cAy4KNmtizFdl8HslYLqXywhe6CnHn2LSIyLpkmgn7nXD+AmRU653YAYw3uewawyzn3mnNuELgDuCjJdv8E3A00ZxjL1IpFqYq1MaghKkUkoDJNBPXxdgT3Ab81s/sZe6jKOmBv4mfElx1iZnV4w17eku6DzGytmW0wsw0tLWlvRMbNdTeTT5RomYaoFJFgyrRl8cXxyRvM7AmgEnh4jN2SFbiPrHL6beAa51w0Xfm8c+5W4FaAVatWTWm11Z6WvZQBoUq1KhaRYBp3NRnn3JNjbwV4dwCJLbQWMPouYhVwRzwJ1AIXmFnEOXffeOOaqI7mPZQBxTVqTCYiweRnfcnngKVmthhoAD4CXJq4gXNu8fC0mf0IeHA6kwBAb6vXqrhsthqTiUgw+ZYInHMRM7sSrzZQHnCbc26rmV0eX5/2ucB0ibQ3MOTyqJlTN/bGIiI5yNcWVM659cD6EcuSJgDn3N/5GUsq1rmPZqqYU1mcjcOLiGRdprWGcla4dz8HrIaicF62QxERyYrAJ4KSgWY68tXHkIgEV7ATgXNUDbXQV6xWxSISXMFOBAOdFNPPUImGqBSR4Ap0Ioh2xJs1aIhKEQmwQCeCzmavDUG4WlVHRSS4Ap0Ielq8RFBSs2iMLUVEclegE0F/Wz0AlXPUvYSIBFegE0Gso4E2V8bcmqpshyIikjWBTgR53fvY72qoKS3IdigiIlkT6ERQ1NdEW14N+XmBPg0iEnCBvgKWDbbQU6BWxSISbMFNBJEBKmPt9BWrMZmIBFtwE0HXfgBiGqJSRAIusIlgsN2rOpqnISpFJOACmwg6m7zGZIUaolJEAi6wiaCvdS8AZbOVCEQk2AKbCIYONtDnCqitVRfUIhJsgU0EdDWyz81iboWGqBSRYAtsIgj37KeZWVSVhLMdiohIVgU2EZQMNNMZrsXMsh2K
iEhWBTMRxGJUDh2gt1DPB0REgpkIelvJJ8JgqRqTiYgEMxF0xYeoLFdjMhGRQCaC4TYE4SolAhGRQCaCrvhYxSWzNUSliEggE8FAWz1RZ1TUatB6EZFAJoJYRwMtVDG3qjTboYiIZF0gE0Goez/73SzmVhRlOxQRkawLZCIo7GvigNVQWpif7VBERLIukImgfLCZLg1RKSICBDERDHRTHOuhv3hutiMREZkRgpcIuvYBGqJSRGRY4BJBrL0BAKtQYzIREfA5EZjZ+Wb2spntMrNrk6z/mJltir+eNrNT/YwHoKc1PkTlrAV+H0pE5KjgWyIwszzgJmA1sAz4qJktG7HZX4B3OueWA18FbvUrnmG9B7xB68vUqlhEBPD3juAMYJdz7jXn3CBwB3BR4gbOuaedcwfjs88Avv9MHzpYT4croXZWtd+HEhE5KviZCOqAvQnz9fFlqfwD8JtkK8xsrZltMLMNLS0tk4uqq5F9roZjKtWYTEQE/E0EyYb+ckk3NDsHLxFck2y9c+5W59wq59yq2bMnV/8/3NNEk6tmdlnhpD5HRCRX+JkI6oGFCfMLgMaRG5nZcuAHwEXOuVYf4wGguL+Jg/m1FOQHrsKUiEhSfl4NnwOWmtliMysAPgI8kLiBmS0C7gE+4Zx7xcdYPNEhyiJt9BSqMZmIyDDfOttxzkXM7ErgESAPuM05t9XMLo+vvwW4HqgBvhcfRD7inFvlV0x0NxHCMVSiRCAiMszXXtecc+uB9SOW3ZIw/SngU37GcIROr2TKlatVsYjIsEB1vxlpbyAfyK9WYzKRoBkaGqK+vp7+/v5sh+KroqIiFixYQDgcznifQCWCngN7qQRKahaOua2I5Jb6+nrKy8s57rjjiBdF5xznHK2trdTX17N48eKM9wtU1Zn+1r0MuDBVNXpGIBI0/f391NTU5GwSADAzampqxn3XE6hEEO1oZL+rZm5lcbZDEZEsyOUkMGwif2OgEkGoex/70RCVIiKJApUICnv30+yqqSktyHYoIjLD3fdiA2ete5zF1z7EWese574XGyb1ee3t7Xzve98b934XXHAB7e3tkzr2WIKTCJyjbLCFzoI5hEK5f3soIhN334sNXHfPZhra+3BAQ3sf192zefPkTr8AAAvzSURBVFLJIFUiiEajafdbv349VVVVEz5uJnK/1tCNS6GnGYAw8LHo/XBDJZTOgat3Zjc2EcmKr/x6K9saO1Ouf3FPO4PR2BHL+oai/K+7NvGLZ/ck3WfZ/Aq+/IGTU37mtddey6uvvsqKFSsIh8OUlZUxb948Nm7cyLZt21izZg179+6lv7+fz3/+86xduxaA4447jg0bNtDd3c3q1at5+9vfztNPP01dXR33338/xcWTf+aZ+3cE8SSQ8XIRCbyRSWCs5ZlYt24dS5YsYePGjdx44408++yzfO1rX2Pbtm0A3HbbbTz//PNs2LCB73znO7S2ju56befOnVxxxRVs3bqVqqoq7r777gnHkyj37whEREZI98sd4Kx1j9PQ3jdqeV1VMb/8zJlTEsMZZ5xxRF3/73znO9x7770A7N27l507d1JTU3PEPosXL2bFihUAnH766bz++utTEkvu3xGIiIzT1eedSHE474hlxeE8rj7vxCk7Rmlp6aHp3//+9zz22GP86U9/4qWXXmLlypVJ2wIUFh7uPj8vL49IJDIlseiOQERkhDUrvTG0bnzkZRrb+5hfVczV5514aPlElJeX09XVlXRdR0cH1dXVlJSUsGPHDp555pkJH2cilAhERJJYs7JuUhf+kWpqajjrrLM45ZRTKC4uZu7cwz0cnH/++dxyyy0sX76cE088kbe+9a1TdtxM5HwiaHGVzLaO5MuzEI+IBNftt9+edHlhYSG/+U3SkXoPPQeora1ly5Yth5ZfddVVUxZXzieCNcU/SvnQ549ZiEdEZKbJ+YfF0/HQR0TkaJb7dwQ+PPQREcklOZ8IYOof+oiI5JKcLxoSEZH0
lAhERAIuEEVDIiLjktBZ5REm0Vlle3s7t99+O5/97GfHve+3v/1t1q5dS0lJyYSOPRbdEYiIjORDZ5UTHY8AvETQ29s74WOPRXcEIhI8v7kW9m+e2L4/fH/y5ce8CVavS7lbYjfU733ve5kzZw533nknAwMDXHzxxXzlK1+hp6eHD3/4w9TX1xONRvn3f/93mpqaaGxs5JxzzqG2tpYnnnhiYnGnoUQgIjIN1q1bx5YtW9i4cSOPPvood911F88++yzOOT74wQ/y1FNP0dLSwvz583nooYcArw+iyspKvvWtb/HEE09QW1vrS2xKBCISPGl+uQPe4FWpXPbQpA//6KOP8uijj7Jy5UoAuru72blzJ2effTZXXXUV11xzDRdeeCFnn332pI+VCSUCEZFp5pzjuuuu4zOf+cyodc8//zzr16/nuuuu433vex/XX3+97/HoYbGIyEilc8a3PAOJ3VCfd9553HbbbXR3dwPQ0NBAc3MzjY2NlJSU8PGPf5yrrrqKF154YdS+ftAdgYjISD6MZ57YDfXq1au59NJLOfNMb7SzsrIyfvazn7Fr1y6uvvpqQqEQ4XCYm2++GYC1a9eyevVq5s2b58vDYnPOTfmH+mnVqlVuw4YN2Q5DRI4y27dv56STTsp2GNMi2d9qZs8751Yl215FQyIiAadEICIScEoEIhIYR1tR+ERM5G9UIhCRQCgqKqK1tTWnk4FzjtbWVoqKisa1n2oNiUggLFiwgPr6elpaWrIdiq+KiopYsGDBuPZRIhCRQAiHwyxevDjbYcxIvhYNmdn5Zvayme0ys2uTrDcz+058/SYzO83PeEREZDTfEoGZ5QE3AauBZcBHzWzZiM1WA0vjr7XAzX7FIyIiyfl5R3AGsMs595pzbhC4A7hoxDYXAT9xnmeAKjOb52NMIiIygp/PCOqAvQnz9cBbMtimDtiXuJGZrcW7YwDoNrOXJxhTLXBggvtOh5keH8z8GBXf5Ci+yZnJ8R2baoWficCSLBtZbyuTbXDO3QrcOumAzDakamI9E8z0+GDmx6j4JkfxTc5Mjy8VP4uG6oGFCfMLgMYJbCMiIj7yMxE8Byw1s8VmVgB8BHhgxDYPAH8brz30VqDDObdv5AeJiIh/fCsacs5FzOxK4BEgD7jNObfVzC6Pr78FWA9cAOwCeoHL/IonbtLFSz6b6fHBzI9R8U2O4pucmR5fUkddN9QiIjK11NeQiEjAKRGIiARcTiaCmdy1hZktNLMnzGy7mW01s88n2eZdZtZhZhvjL/9Hrz7y+K+b2eb4sUcNB5fl83diwnnZaGadZvaFEdtM+/kzs9vMrNnMtiQsm2VmvzWznfH36hT7pv2++hjfjWa2I/5veK+ZVaXYN+33wcf4bjCzhoR/xwtS7Jut8/fLhNheN7ONKfb1/fxNmnMup154D6ZfBY4HCoCXgGUjtrkA+A1eO4a3An+exvjmAafFp8uBV5LE9y7gwSyew9eB2jTrs3b+kvxb7weOzfb5A94BnAZsSVj2f4Br49PXAl9P8Tek/b76GN/7gPz49NeTxZfJ98HH+G4ArsrgO5CV8zdi/TeB67N1/ib7ysU7ghndtYVzbp9z7oX4dBewHa819dFkpnQNci7wqnNudxaOfQTn3FNA24jFFwE/jk//GFiTZNdMvq++xOece9Q5F4nPPoPXjicrUpy/TGTt/A0zMwM+DPxiqo87XXIxEaTqtmK82/jOzI4DVgJ/TrL6TDN7ycx+Y2YnT2tgXuvuR83s+Xj3HiPNiPOH1zYl1X++bJ6/YXNdvF1M/H1Okm1myrn8e7y7vGTG+j746cp40dVtKYrWZsL5Oxtocs7tTLE+m+cvI7mYCKasaws/mVkZcDfwBedc54jVL+AVd5wKfBe4bzpjA85yzp2G1zvsFWb2jhHrZ8L5KwA+CPwqyepsn7/xmAnn8ktABPh5ik3G+j745WZgCbACr/+xbybZJuvnD/go6e8GsnX+MpaLiWDG
d21hZmG8JPBz59w9I9c75zqdc93x6fVA2Mxqpys+51xj/L0ZuBfv9jvRTOgaZDXwgnOuaeSKbJ+/BE3DRWbx9+Yk22T7u/hJ4ELgYy5eoD1SBt8HXzjnmpxzUedcDPi/KY6b7fOXD3wI+GWqbbJ1/sYjFxPBjO7aIl6e+P+A7c65b6XY5pj4dpjZGXj/Tq3TFF+pmZUPT+M9UNwyYrOZ0DVIyl9h2Tx/IzwAfDI+/Ung/iTbZPJ99YWZnQ9cA3zQOdebYptMvg9+xZf43OniFMfN2vmLew+wwzlXn2xlNs/fuGT7abUfL7xaLa/g1Sb4UnzZ5cDl8WnDGzTnVWAzsGoaY3s73q3rJmBj/HXBiPiuBLbi1YB4BnjbNMZ3fPy4L8VjmFHnL378ErwLe2XCsqyeP7yktA8YwvuV+g9ADfA7YGf8fVZ82/nA+nTf12mKbxde+frw9/CWkfGl+j5MU3w/jX+/NuFd3OfNpPMXX/6j4e9dwrbTfv4m+1IXEyIiAZeLRUMiIjIOSgQiIgGnRCAiEnBKBCIiAadEICIScEoEIj6L94b6YLbjEElFiUBEJOCUCETizOzjZvZsvN/475tZnpl1m9k3zewFM/udmc2Ob7vCzJ5J6Mu/Or78DWb2WLzDuxfMbEn848vM7K54//8/T2j5vM7MtsU/5xtZ+tMl4JQIRAAzOwn4G7wOwlYAUeBjQClen0anAU8CX47v8hPgGufccrzWr8PLfw7c5LwO796G1xoVvF5mvwAsw2ttepaZzcLrOuHk+Of8h79/pUhySgQinnOB04Hn4iNNnYt3wY5xuEOxnwFvN7NKoMo592R8+Y+Bd8T7lKlzzt0L4Jzrd4f78HnWOVfvvA7UNgLHAZ1AP/ADM/sQkLS/HxG/KRGIeAz4sXNuRfx1onPuhiTbpeuTJVmXyMMGEqajeCODRfB6orwbb9Cah8cZs8iUUCIQ8fwOuMTM5sCh8YaPxfs/ckl8m0uB/3HOdQAHzezs+PJPAE86b1yJejNbE/+MQjMrSXXA+JgUlc7rKvsLeP3ui0y7/GwHIDITOOe2mdm/4Y0kFcLrZfIKoAc42cyeBzrwniOA1630LfEL/WvAZfHlnwC+b2b/O/4Zf53msOXA/WZWhHc38c9T/GeJZES9j4qkYWbdzrmybMch4icVDYmIBJzuCEREAk53BCIiAadEICIScEoEIiIBp0QgIhJwSgQiIgH3/wHfd1hTobeerQAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mnist import load_mnist\n",
    "from common.trainer import Trainer\n",
    "\n",
    "# 读入数据\n",
    "(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False)\n",
    "\n",
    "# 处理花费时间较长的情况下减少数据\n",
    "x_train, t_train = x_train[:5000], t_train[:5000]\n",
    "x_test, t_test = x_test[:1000], t_test[:1000]\n",
    "\n",
    "max_epochs = 20\n",
    "\n",
    "network = SimpleConvNet(input_dim=(1,28,28),\n",
    "                        conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},\n",
    "                        hidden_size=100, output_size=10, weight_init_std=0.01)\n",
    "\n",
    "trainer = Trainer(network, x_train, t_train, x_test, t_test,\n",
    "                  epochs=max_epochs, mini_batch_size=100,\n",
    "                  optimizer='Adam', optimizer_param={'lr': 0.001},\n",
    "                  evaluate_sample_num_per_epoch=1000)\n",
    "trainer.train()\n",
    "\n",
    "# 保存参数\n",
    "network.save_params(\"params.pkl\")\n",
    "print(\"Saved Network Parameters!\")\n",
    "\n",
    "# 绘制图形\n",
    "markers = {'train': 'o', 'test': 's'}\n",
    "x = np.arange(max_epochs)\n",
    "plt.plot(x, trainer.train_acc_list, marker='o', label='train', markevery=2)\n",
    "plt.plot(x, trainer.test_acc_list, marker='s', label='test', markevery=2)\n",
    "plt.xlabel(\"epochs\")\n",
    "plt.ylabel(\"accuracy\")\n",
    "plt.ylim(0, 1.0)\n",
    "plt.legend(loc='lower right')\n",
    "plt.show()\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-04T04:37:45.370634800Z",
     "start_time": "2023-10-04T04:30:28.295611200Z"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
